package mongoclient

import (
	"context"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"gopkg.in/mgo.v2/bson"
)

// Session is a chainable query builder bound to one database and, after
// WithCollection, one collection. Builder methods (Where, Limit, ...) mutate
// the Session in place and return the same pointer.
//
// NOTE(review): nothing in this file resets statement after a query, so
// reusing a Session carries its conditions into later queries.
type Session struct {
	engine     *Engine           // owning engine; supplies the mongo Client
	database   string            // database name, copied from the engine
	collection *mongo.Collection // target collection; nil until WithCollection is called
	statement  *Statement        // accumulated filter, paging and sort state
	lastQuery  string            // NOTE(review): not read or written in this file — confirm it is used elsewhere
}
// MongoTable is implemented by model types that can be decoded by Row.
// TableName presumably returns the collection name for the model — confirm
// against implementers; it is not called in this file.
type MongoTable interface {
	TableName() string
}

// NewSession creates a Session bound to the engine's database, with a new
// statement from NewStatement. The collection is left unset, so call
// WithCollection before issuing any query.
func (eg *Engine) NewSession() *Session {
	sess := new(Session)
	sess.engine = eg
	sess.statement = NewStatement()
	sess.database = eg.database
	return sess
}

// WithCollection points the session at collectionName in its database and
// returns the same session for chaining. The query methods use s.collection,
// which is nil until this has been called.
func (s *Session) WithCollection(collectionName string) *Session {
	db := s.engine.Client.Database(s.database)
	s.collection = db.Collection(collectionName)
	return s
}

// InsertOne inserts a single document into the session's collection and
// returns the _id reported by the driver, or nil and the error on failure.
func (s *Session) InsertOne(ctx context.Context, data interface{}) (interface{}, error) {
	res, err := s.collection.InsertOne(ctx, data)
	if err == nil {
		return res.InsertedID, nil
	}
	return nil, err
}

// InsertMany inserts the given documents into the session's collection and
// returns the _id values reported by the driver, or nil and the error on
// failure.
//
// ctx is forwarded to the driver so the caller's cancellation and deadline
// apply to the insert. An earlier version passed context.TODO(), which
// silently dropped them.
func (s *Session) InsertMany(ctx context.Context, data []interface{}) ([]interface{}, error) {
	res, err := s.collection.InsertMany(ctx, data)
	if err != nil {
		return nil, err
	}
	return res.InsertedIDs, nil
}

// Update applies update under "$set" to every document matching the
// session's filter (via UpdateMany) and returns the driver's UpdateResult.
//
// NOTE(review): update is wrapped in $set as-is, so it presumably must be a
// plain field->value document, not one already containing update operators.
func (s *Session) Update(ctx context.Context, update interface{}) (updateResult *mongo.UpdateResult, err error) {
	setDoc := bson.M{"$set": update}
	return s.collection.UpdateMany(ctx, s.statement.cond, setDoc)
}

// DelOne deletes at most one document matching the session's filter.
//
// NOTE(review): if no conditions were added, the filter is whatever
// NewStatement initialises cond to. An empty filter matches any document.
func (s *Session) DelOne(ctx context.Context) (delResult *mongo.DeleteResult, err error) {
	return s.collection.DeleteOne(ctx, s.statement.cond)
}

// DelMany deletes every document matching the session's filter.
//
// NOTE(review): if no conditions were added, the filter is whatever
// NewStatement initialises cond to. An empty filter deletes the whole
// collection.
func (s *Session) DelMany(ctx context.Context) (delResult *mongo.DeleteResult, err error) {
	return s.collection.DeleteMany(ctx, s.statement.cond)
}
// Where adds a condition on field to the session's statement. The optional
// operator is interpreted by Statement.Where. Returns s for chaining.
func (s *Session) Where(field string, value interface{}, operator ...string) *Session {
	st := s.statement
	st.Where(field, value, operator...)
	return s
}
// AndWhere is an alias for Where, provided for readability in chains.
func (s *Session) AndWhere(field string, value interface{}, operator ...string) *Session {
	return s.Where(field, value, operator...)
}
// OrWhere adds a disjunction of the given members to the session's
// statement via Statement.OrWhere. Returns s for chaining.
func (s *Session) OrWhere(members ...OrMember) *Session {
	st := s.statement
	st.OrWhere(members...)
	return s
}
// WhereIn adds a membership condition (field in data) to the session's
// statement via Statement.WhereIn. Returns s for chaining.
func (s *Session) WhereIn(field string, data []interface{}) *Session {
	st := s.statement
	st.WhereIn(field, data)
	return s
}

// Skip records an offset of num documents on the session's statement.
// Returns s for chaining.
func (s *Session) Skip(num int64) *Session {
	st := s.statement
	st.Skip(num)
	return s
}
// Order records a sort specification on the session's statement. It is
// later passed to the driver's SetSort. Returns s for chaining.
func (s *Session) Order(order interface{}) *Session {
	st := s.statement
	st.Order(order)
	return s
}
// Limit records a maximum result count of num on the session's statement.
// Returns s for chaining.
func (s *Session) Limit(num int64) *Session {
	st := s.statement
	st.Limit(num)
	return s
}
// FindAll runs the session's filter with the options from getOptions and
// returns the raw driver cursor. The caller owns the cursor and must Close
// it. The cursor is nil when err is non-nil.
func (s *Session) FindAll(ctx context.Context) (cur *mongo.Cursor, err error) {
	findOpts := s.getOptions()
	return s.collection.Find(ctx, s.statement.cond, findOpts)
}
// getOptions builds the FindOptions for a multi-document query from the
// paging and ordering state recorded on the statement.
//
// Limit, skip and sort are each applied only when set (non-zero / non-nil)
// and independently of one another.
func (s *Session) getOptions() (option *options.FindOptions) {
	option = options.Find()
	if s.statement.limit != 0 {
		option.SetLimit(s.statement.limit)
	}
	// Skip is meaningful without a limit (offset with no page size). Gating
	// it on limit silently dropped Skip(n) calls that had no Limit.
	if s.statement.skip != 0 {
		option.SetSkip(s.statement.skip)
	}
	if s.statement.order != nil {
		option.SetSort(s.statement.order)
	}
	return option
}

// Find runs the session's filter with the options built by getOptions and
// decodes every matching document into a map[string]interface{}.
//
// The decoded documents are returned in cursor order. A non-nil error is
// returned if the query fails, a document cannot be decoded, or the cursor
// reports an iteration error. In every error case resultList is nil.
func (s *Session) Find(ctx context.Context) (resultList []interface{}, err error) {
	cur, err := s.collection.Find(ctx, s.statement.cond, s.getOptions())
	if err != nil {
		// cur is nil on error; deferring Close on it would panic.
		return nil, err
	}
	defer cur.Close(ctx)

	for cur.Next(ctx) {
		result := make(map[string]interface{})
		// Propagate decode failures instead of breaking out and reporting
		// success with a truncated result set.
		if err := cur.Decode(result); err != nil {
			return nil, err
		}
		resultList = append(resultList, result)
	}
	if err := cur.Err(); err != nil {
		return nil, err
	}
	return resultList, nil
}

// Row fetches the first document matching the session's filter and decodes
// it into result.
//
// The statement's skip and sort order are honoured, each only when set. The
// returned error is the driver's: mongo.ErrNoDocuments when nothing matches,
// or any query/decode failure. The previous version returned nothing and
// discarded this error. Existing call statements still compile.
func (s *Session) Row(ctx context.Context, result MongoTable) error {
	option := options.FindOne()
	// Skip is meaningful without a limit; it was previously gated on limit
	// and silently ignored otherwise.
	if s.statement.skip != 0 {
		option.SetSkip(s.statement.skip)
	}
	// Apply Order so "first" means first in the requested order, matching
	// what getOptions does for multi-document queries.
	if s.statement.order != nil {
		option.SetSort(s.statement.order)
	}
	return s.collection.FindOne(ctx, s.statement.cond, option).Decode(result)
}

// GetTableStructData fetches the first document matching the session's
// filter and decodes it into a generic map. It is intended for inspecting
// the structure of stored documents.
//
// NOTE(review): the query/decode error is discarded, so callers cannot
// tell "no match" from failure. The map may be nil or only partially filled
// in either case. Surfacing the error would require a signature change.
func (s *Session) GetTableStructData(ctx context.Context) (result map[string]interface{}) {
	single := s.collection.FindOne(ctx, s.statement.cond, options.FindOne())
	_ = single.Decode(&result) // error intentionally ignored; see NOTE above
	return result
}
